# PLOT Supplementary FIGURE 6B
# Data = Longitudinal samples
# Exposure = Antimicrobial class (+ covariates)
# Outcome = Abundance of selected bacterial AMR genes
# Requires output of scripts 1, 2 & 3

### Data table  ----
# Model input: one row per entry of l_pairs, annotated via left joins with
# patient characteristics (keyed on pid) and the WCC / CRP / NEWS tables, the
# antimicrobial-class exposure table and the ARG abundance table (all keyed on
# pair_id). conditioning_day is copied from collected.y.
# NOTE(review): collected.y presumably refers to the second sample of each
# pair — confirm against l_pairs.
data_for_LS_AM_class_ARG_model <- l_pairs |>
  left_join(select(l_patients, pid, age_category, sex, tx), by = "pid") |>
  left_join(l_wcc, by = "pair_id") |>
  left_join(l_crp, by = "pair_id") |>
  left_join(l_news, by = "pair_id") |>
  left_join(table_of_pairs_with_AM_class_exposures, by = "pair_id") |>
  left_join(l_argRA, by = "pair_id") |>
  mutate(conditioning_day = collected.y)

### Exposures ----
# Right-hand-side terms shared by every ARG model below (c() flattens the
# grouped sub-vectors into one character vector, in this order).
names_of_all_exposures_in_LS_AM_class_ARG_model <- c(
  # antimicrobial-class exposure terms (defined upstream)
  names_of_pair_AM_class_exposures_excluding_rarities,
  # patient characteristics
  c("age_category", "sex", "tx"),
  # sampling-time covariates
  c("conditioning_day", "sample_separation"),
  # clinical-marker indicators
  c("new_low_wcc", "new_high_wcc", "new_high_crp", "news_increase")
)

### ARG models ----
# Note: each model also includes the baseline abundance (log_<gene>_rpm_trunc.x) as a covariate

# > Beta-lactamases ----
# Linear model of log_bla_rpm_trunc_diff on all shared exposures/covariates
# plus the baseline abundance log_bla_rpm_trunc.x.
multivariable_LS_AM_class_bla_model <-
  lm(reformulate(c(names_of_all_exposures_in_LS_AM_class_ARG_model,
                   "log_bla_rpm_trunc.x"),
                 response = "log_bla_rpm_trunc_diff"),
     data = data_for_LS_AM_class_ARG_model)

# Coefficient tests with a covariance matrix clustered on pid (presumably
# because one patient can contribute several sample pairs).
robust_multivariable_LS_AM_class_bla_model <-
  coeftest(multivariable_LS_AM_class_bla_model,
           cluster.vcov(multivariable_LS_AM_class_bla_model,
                        data_for_LS_AM_class_ARG_model$pid))

# One row per non-intercept coefficient. Columns of the coeftest matrix are
# taken by position: estimate, std. error, t statistic, p value. `ci` is the
# half-width of a normal-approximation 95% interval (1.96 * robust SE); the
# 10^ back-transform is consistent with a log10-scale outcome (confirm).
# rownames() is used so the names survive even if only one coefficient remains.
robust_multivariable_LS_AM_class_bla_model_data_frame <-
  tibble(variable = rownames(robust_multivariable_LS_AM_class_bla_model)[-1],
         effect = robust_multivariable_LS_AM_class_bla_model[-1, 1],
         se = robust_multivariable_LS_AM_class_bla_model[-1, 2],
         ci = 1.96 * robust_multivariable_LS_AM_class_bla_model[-1, 2],
         t = robust_multivariable_LS_AM_class_bla_model[-1, 3],
         p = robust_multivariable_LS_AM_class_bla_model[-1, 4]) |>
  mutate(effect_fold = 10^effect,
         upper = 10^(effect + ci),
         lower = 10^(effect - ci),
         group = "bla")

# > Tetracycline (RPP)  ----
# Linear model of log_tet_rpm_trunc_diff on all shared exposures/covariates
# plus the baseline abundance log_tet_rpm_trunc.x.
multivariable_LS_AM_class_tet_model <-
  lm(reformulate(c(names_of_all_exposures_in_LS_AM_class_ARG_model,
                   "log_tet_rpm_trunc.x"),
                 response = "log_tet_rpm_trunc_diff"),
     data = data_for_LS_AM_class_ARG_model)

# Coefficient tests with a covariance matrix clustered on pid (presumably
# because one patient can contribute several sample pairs).
robust_multivariable_LS_AM_class_tet_model <-
  coeftest(multivariable_LS_AM_class_tet_model,
           cluster.vcov(multivariable_LS_AM_class_tet_model,
                        data_for_LS_AM_class_ARG_model$pid))

# One row per non-intercept coefficient. Columns of the coeftest matrix are
# taken by position: estimate, std. error, t statistic, p value. `ci` is the
# half-width of a normal-approximation 95% interval (1.96 * robust SE); the
# 10^ back-transform is consistent with a log10-scale outcome (confirm).
# rownames() is used so the names survive even if only one coefficient remains.
robust_multivariable_LS_AM_class_tet_model_data_frame <-
  tibble(variable = rownames(robust_multivariable_LS_AM_class_tet_model)[-1],
         effect = robust_multivariable_LS_AM_class_tet_model[-1, 1],
         se = robust_multivariable_LS_AM_class_tet_model[-1, 2],
         ci = 1.96 * robust_multivariable_LS_AM_class_tet_model[-1, 2],
         t = robust_multivariable_LS_AM_class_tet_model[-1, 3],
         p = robust_multivariable_LS_AM_class_tet_model[-1, 4]) |>
  mutate(effect_fold = 10^effect,
         upper = 10^(effect + ci),
         lower = 10^(effect - ci),
         group = "tet")

# > Aminoglycoside (AAC, ANT, APH) ----
# Linear model of log_amg_rpm_trunc_diff on all shared exposures/covariates
# plus the baseline abundance log_amg_rpm_trunc.x.
multivariable_LS_AM_class_amg_model <-
  lm(reformulate(c(names_of_all_exposures_in_LS_AM_class_ARG_model,
                   "log_amg_rpm_trunc.x"),
                 response = "log_amg_rpm_trunc_diff"),
     data = data_for_LS_AM_class_ARG_model)

# Coefficient tests with a covariance matrix clustered on pid (presumably
# because one patient can contribute several sample pairs).
robust_multivariable_LS_AM_class_amg_model <-
  coeftest(multivariable_LS_AM_class_amg_model,
           cluster.vcov(multivariable_LS_AM_class_amg_model,
                        data_for_LS_AM_class_ARG_model$pid))

# One row per non-intercept coefficient. Columns of the coeftest matrix are
# taken by position: estimate, std. error, t statistic, p value. `ci` is the
# half-width of a normal-approximation 95% interval (1.96 * robust SE); the
# 10^ back-transform is consistent with a log10-scale outcome (confirm).
# rownames() is used so the names survive even if only one coefficient remains.
robust_multivariable_LS_AM_class_amg_model_data_frame <-
  tibble(variable = rownames(robust_multivariable_LS_AM_class_amg_model)[-1],
         effect = robust_multivariable_LS_AM_class_amg_model[-1, 1],
         se = robust_multivariable_LS_AM_class_amg_model[-1, 2],
         ci = 1.96 * robust_multivariable_LS_AM_class_amg_model[-1, 2],
         t = robust_multivariable_LS_AM_class_amg_model[-1, 3],
         p = robust_multivariable_LS_AM_class_amg_model[-1, 4]) |>
  mutate(effect_fold = 10^effect,
         upper = 10^(effect + ci),
         lower = 10^(effect - ci),
         group = "amg")

# > Macrolide (mef & erm) ----
# Linear model of log_mac_rpm_trunc_diff on all shared exposures/covariates
# plus the baseline abundance log_mac_rpm_trunc.x.
multivariable_LS_AM_class_mac_model <-
  lm(reformulate(c(names_of_all_exposures_in_LS_AM_class_ARG_model,
                   "log_mac_rpm_trunc.x"),
                 response = "log_mac_rpm_trunc_diff"),
     data = data_for_LS_AM_class_ARG_model)

# Coefficient tests with a covariance matrix clustered on pid (presumably
# because one patient can contribute several sample pairs).
robust_multivariable_LS_AM_class_mac_model <-
  coeftest(multivariable_LS_AM_class_mac_model,
           cluster.vcov(multivariable_LS_AM_class_mac_model,
                        data_for_LS_AM_class_ARG_model$pid))

# One row per non-intercept coefficient. Columns of the coeftest matrix are
# taken by position: estimate, std. error, t statistic, p value. `ci` is the
# half-width of a normal-approximation 95% interval (1.96 * robust SE); the
# 10^ back-transform is consistent with a log10-scale outcome (confirm).
# rownames() is used so the names survive even if only one coefficient remains.
robust_multivariable_LS_AM_class_mac_model_data_frame <-
  tibble(variable = rownames(robust_multivariable_LS_AM_class_mac_model)[-1],
         effect = robust_multivariable_LS_AM_class_mac_model[-1, 1],
         se = robust_multivariable_LS_AM_class_mac_model[-1, 2],
         ci = 1.96 * robust_multivariable_LS_AM_class_mac_model[-1, 2],
         t = robust_multivariable_LS_AM_class_mac_model[-1, 3],
         p = robust_multivariable_LS_AM_class_mac_model[-1, 4]) |>
  mutate(effect_fold = 10^effect,
         upper = 10^(effect + ci),
         lower = 10^(effect - ci),
         group = "mac")

# > Glycopeptide (VanA)  ----
# Linear model of log_van_rpm_trunc_diff on all shared exposures/covariates
# plus the baseline abundance log_van_rpm_trunc.x.
multivariable_LS_AM_class_van_model <-
  lm(reformulate(c(names_of_all_exposures_in_LS_AM_class_ARG_model,
                   "log_van_rpm_trunc.x"),
                 response = "log_van_rpm_trunc_diff"),
     data = data_for_LS_AM_class_ARG_model)

# Coefficient tests with a covariance matrix clustered on pid (presumably
# because one patient can contribute several sample pairs).
robust_multivariable_LS_AM_class_van_model <-
  coeftest(multivariable_LS_AM_class_van_model,
           cluster.vcov(multivariable_LS_AM_class_van_model,
                        data_for_LS_AM_class_ARG_model$pid))

# One row per non-intercept coefficient. Columns of the coeftest matrix are
# taken by position: estimate, std. error, t statistic, p value. `ci` is the
# half-width of a normal-approximation 95% interval (1.96 * robust SE); the
# 10^ back-transform is consistent with a log10-scale outcome (confirm).
# rownames() is used so the names survive even if only one coefficient remains.
robust_multivariable_LS_AM_class_van_model_data_frame <-
  tibble(variable = rownames(robust_multivariable_LS_AM_class_van_model)[-1],
         effect = robust_multivariable_LS_AM_class_van_model[-1, 1],
         se = robust_multivariable_LS_AM_class_van_model[-1, 2],
         ci = 1.96 * robust_multivariable_LS_AM_class_van_model[-1, 2],
         t = robust_multivariable_LS_AM_class_van_model[-1, 3],
         p = robust_multivariable_LS_AM_class_van_model[-1, 4]) |>
  mutate(effect_fold = 10^effect,
         upper = 10^(effect + ci),
         lower = 10^(effect - ci),
         group = "van")

# Merge tables ----
# Stack the five per-ARG coefficient tables, then right-join them onto every
# (antimicrobial class x ARG group) combination. Every class therefore appears
# once per group; combinations with no matching coefficient keep NA estimates.
# Model terms with no match in drug_group_long are dropped by the right join.
combined_LS_AM_class_ARG_model_data_frame <-
  bind_rows(robust_multivariable_LS_AM_class_bla_model_data_frame,
            robust_multivariable_LS_AM_class_tet_model_data_frame,
            robust_multivariable_LS_AM_class_amg_model_data_frame,
            robust_multivariable_LS_AM_class_mac_model_data_frame,
            robust_multivariable_LS_AM_class_van_model_data_frame) %>%
  right_join(number_of_pairs_with_each_AM_class_exposure %>%
               # `by = character()` requests a cross join (dplyr >= 1.1.0
               # offers cross_join() for this).
               full_join(tibble(group = c("bla", "tet", "amg", "mac", "van")),
                         by = character()),
             by = c("variable" = "drug_group_long", "group")) %>%
  mutate(variable = str_replace_all(variable, "_", " "),
         variable = str_to_sentence(variable),
         # Levels in reverse collation order, so the discrete y axis (drawn
         # bottom-up in level order) reads alphabetically from the top.
         variable = fct_reorder(variable, desc(variable)),
         # Counts below 6 are masked as "-" in both the figure and the export.
         n = if_else(n < 6, "-", as.character(n)))

# Plot ----
# Forest plot: one row per antimicrobial class. For each ARG group, a point
# shows the fold change (effect_fold) and a horizontal bar spans lower..upper.
# The groups are vertically nudged within each row and drawn in fixed colours.
plot_LS_AM_class_ARG_model <- local({
  # Drawing order, vertical offset and colour of each ARG group.
  arg_group_styles <- data.frame(
    group = c("tet", "bla", "amg", "mac", "van"),
    nudge = c(0.2, 0.1, 0.0, -0.1, -0.2),
    colour = c("#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e")
  )

  # One (point, error bar) layer pair per ARG group, in the order above.
  estimate_layers <- lapply(seq_len(nrow(arg_group_styles)), function(i) {
    this_group <- arg_group_styles$group[i]
    this_nudge <- arg_group_styles$nudge[i]
    this_colour <- arg_group_styles$colour[i]
    group_data <- combined_LS_AM_class_ARG_model_data_frame %>%
      filter(group == this_group)
    list(
      geom_point(data = group_data,
                 aes(y = variable, x = effect_fold),
                 position = position_nudge(y = this_nudge),
                 colour = this_colour),
      # NOTE(review): ggplot2 >= 3.4 prefers `linewidth` over `size` here.
      geom_errorbarh(data = group_data,
                     aes(y = variable, xmin = lower, xmax = upper),
                     height = 0,
                     position = position_nudge(y = this_nudge),
                     colour = this_colour,
                     size = 1)
    )
  })

  ggplot() +
    estimate_layers +
    # Reference line at "no change" (fold change of 1).
    geom_vline(xintercept = 1) +
    # Exposure counts (one label per class, taken from the tet rows).
    geom_text(data = combined_LS_AM_class_ARG_model_data_frame %>% filter(group == "tet"),
              aes(y = variable,
                  x = 10^4.4,
                  label = n)) +
    # ARG LABELS - not needed if on opposing panel
    # geom_label(aes(x = 10^4.5, y = 8.5, label = "Tetracycline (RPP) "), colour = "#1b9e77", fontface = "bold", hjust = "right") +
    # geom_label(aes(x = 10^4.5, y = 5, label = "Beta-lactamase"), colour = "#d95f02", fontface = "bold", hjust = "right") +
    # geom_label(aes(x = 10^4.5, y = 4.5, label = "Aminoglycoside (AAC, ANT, APH)"), colour = "#7570b3", fontface = "bold", hjust = "right") +
    # geom_label(aes(x = 10^4.5, y = 4, label = "Macrolide (mef & erm)"), colour = "#e7298a", fontface = "bold", hjust = "right") +
    # geom_label(aes(x = 10^4.5, y = 3.5, label = "Glycopeptide (VanA)"), colour = "#66a61e", fontface = "bold", hjust = "right") +
    scale_x_log10(breaks = c(1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4),
                  labels = scientific) +
    scale_y_discrete(position = "right") +
    coord_cartesian(xlim = c(10^-4.3, 10^4.3)) +
    labs(title = "Supplementary Figure 6B - Longitudinal", x = "Change in relative abundance", y = "") +
    theme(axis.text.y = element_text(size = 10, face = "bold", colour = "black"),
          axis.text.x = element_text(size = 10, face = "bold", colour = "black"),
          #panel.border = element_blank(),
          axis.line.x = element_blank(),
          axis.line = element_line(colour = "black"))
})

print(plot_LS_AM_class_ARG_model)

# Save the plot object explicitly rather than relying on last_plot().
ggsave("plots/Supplementary Figure 6B - Antimicrobial class vs selected AMR genes in longitudinal arm.pdf",
       plot = plot_LS_AM_class_ARG_model,
       width = 148, height = 210, units = "mm")

rm(plot_LS_AM_class_ARG_model)

# Export the estimates behind Supplementary Figure 6B, with human-readable
# column headers, without a row-name column.
write.csv(combined_LS_AM_class_ARG_model_data_frame |>
            select("Variable" = variable,
                   "Multivariable effect" = effect,
                   "Multivariable std error" = se,
                   "Multivariable p value" = p,
                   "Effect multiple" = effect_fold,
                   "Upper 95% CI" = upper,
                   "Lower 95% CI" = lower,
                   "ARG group" = group,
                   "Number exposed" = n),
          "exports/Supplementary Figure 6B data - Antimicrobial class vs selected AMR genes in longitudinal arm.csv",
          row.names = FALSE)

# Drop the intermediate model objects for all five ARG groups. The model input
# table and combined_LS_AM_class_ARG_model_data_frame are not removed here.
rm(list = c(
  "names_of_all_exposures_in_LS_AM_class_ARG_model",
  paste0("multivariable_LS_AM_class_",
         c("bla", "tet", "amg", "mac", "van"), "_model"),
  paste0("robust_multivariable_LS_AM_class_",
         c("bla", "tet", "amg", "mac", "van"), "_model"),
  paste0("robust_multivariable_LS_AM_class_",
         c("bla", "tet", "amg", "mac", "van"), "_model_data_frame")
))